{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# pip install -q transformers\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "\n",
    "\n",
    "checkpoint = \"output_dir/checkpoint-7800\"#\"bigscience/bloomz-3b\" #\"bigscience/bloom-7b1\"# \n",
    "tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n",
    "model = AutoModelForCausalLM.from_pretrained(checkpoint).half().cuda()\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "PROMPT_DICT = {\n",
    "    \"prompt_input\": (\n",
    "        \"Below is an instruction that describes a task, paired with an input that provides further context. \"\n",
    "        \"Write a response that appropriately completes the request.\\n\\n\"\n",
    "        \"### Instruction:\\n{instruction}\\n\\n### Input:\\n{input}\\n\\n### Response:\"\n",
    "    ),\n",
    "    \"prompt_no_input\": (\n",
    "        \"Below is an instruction that describes a task. \"\n",
    "        \"Write a response that appropriately completes the request.\\n\\n\"\n",
    "        \"### Instruction:\\n{instruction}\\n\\n### Response:\"\n",
    "    ),\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Optional\n",
    "def generate_input(instruction:Optional[str]= None, input_str:Optional[str] = None) -> str:\n",
    "    if input_str is None:\n",
    "        return PROMPT_DICT['prompt_no_input'].format_map({'instruction':instruction})\n",
    "    else:\n",
    "        return PROMPT_DICT['prompt_input'].format_map({'instruction':instruction, 'input':input_str})\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(10):\n",
    "    print(\"*\"*80)\n",
    "\n",
    "    inputs = tokenizer.encode(generate_input(instruction=\"写一篇关于水杯的文章\"), return_tensors=\"pt\")\n",
    "    inputs = inputs.to(model.device)\n",
    "    outputs = model.generate(inputs,num_beams=2,\n",
    "                            max_new_tokens=1024,\n",
    "                            do_sample=True, \n",
    "                            top_k=10,\n",
    "                            penalty_alpha=0.6,\n",
    "                            temperature=0.8,\n",
    "                            repetition_penalty=1.2)\n",
    "    print(tokenizer.decode(outputs[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mynet",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
